use dir_writer::{FileCollector, GeneratorArgs, IntermediateRepr, LanguageFeatures};
use functions::{
    render_async_client, render_runtime, render_source_files, render_sync_client, render_type_map,
};
use generated_types::render_py_types;

use crate::{
    functions::{render_config, render_globals, render_init, render_parser, render_tracing},
    generated_types::{
        render_py_stream_types_utils, render_py_type_builder, render_py_types_utils,
    },
};

mod functions;
mod generated_types;
mod ir_to_py;
mod package;
mod r#type;
mod utils;
mod watchers;

/// Code-generation backend that renders the `baml_client` Python package
/// (identified to the generator framework as `"python/pydantic"`).
///
/// Stateless marker type: all inputs arrive through
/// [`LanguageFeatures::generate_sdk_files`].
#[derive(Debug, Clone, Copy, Default)]
pub struct PyLanguageFeatures;

impl LanguageFeatures for PyLanguageFeatures {
    const CONTENT_PREFIX: &'static str = r#"
# ----------------------------------------------------------------------------
#
#  Welcome to Baml! To use this generated code, please run the following:
#
#  $ pip install baml
#
# ----------------------------------------------------------------------------

# This file was generated by BAML: please do not edit it. Instead, edit the
# BAML files and re-generate this code using: baml-cli generate
# baml-cli is available with the baml package.
        "#;

    fn name() -> &'static str {
        "python/pydantic"
    }

    fn generate_sdk_files(
        &self,
        collector: &mut FileCollector<Self>,
        ir: std::sync::Arc<IntermediateRepr>,
        args: &GeneratorArgs,
    ) -> Result<(), anyhow::Error> {
        let pkg = package::CurrentRenderPackage::new(
            "baml_client",
            ir.clone(),
            args.is_pydantic_2.unwrap_or(true),
        );
        let file_map = args.file_map_as_json_string()?;

        // Build function name map for event collectors
        let expr_fn_wrappers = ir.expr_fns_as_functions();
        let mut function_name_map = std::collections::HashMap::new();
        for func in ir.functions.iter() {
            let py_fn = ir_to_py::functions::ir_function_to_py(func, &pkg);
            function_name_map.insert(func.elem.name().to_string(), py_fn.name.clone());
        }
        for func in expr_fn_wrappers.iter() {
            let py_fn = ir_to_py::functions::ir_function_to_py(func, &pkg);
            function_name_map.insert(func.elem.name().to_string(), py_fn.name.clone());
        }

        // Generate event collectors first to determine has_events
        let notification_collectors =
            watchers::build_notification_collectors(args, &pkg, &function_name_map)?;
        collector.add_file(
            "watchers.py",
            watchers::render_events(&notification_collectors)?,
        )?;

        // Generate __init__.py with correct has_events value
        collector.add_file("__init__.py", render_init(&pkg, &args.default_client_mode)?)?;
        collector.add_file("inlinedbaml.py", render_source_files(file_map)?)?;
        collector.add_file("runtime.py", render_runtime(&pkg)?)?;
        collector.add_file("tracing.py", render_tracing(&pkg)?)?;
        collector.add_file("globals.py", render_globals(&pkg)?)?;
        collector.add_file("config.py", render_config(&pkg)?)?;

        let functions = ir
            .functions
            .iter()
            .map(|f| ir_to_py::functions::ir_function_to_py(f, &pkg))
            .chain(
                ir.expr_fns
                    .iter()
                    .map(|f| ir_to_py::functions::ir_expr_fn_to_py(f, &pkg)),
            )
            .collect::<Vec<_>>();
        collector.add_file("async_client.py", render_async_client(&functions, &pkg)?)?;

        collector.add_file("sync_client.py", render_sync_client(&functions, &pkg)?)?;
        collector.add_file("parser.py", render_parser(&functions, &pkg)?)?;

        let py_classes = ir
            .walk_classes()
            .map(|c| ir_to_py::classes::ir_class_to_py(c.item, &pkg))
            .collect::<Vec<_>>();
        let enums = ir
            .walk_enums()
            .map(|e| ir_to_py::enums::ir_enum_to_py(e.item, &pkg))
            .collect::<Vec<_>>();
        let type_aliases = ir.walk_type_aliases().collect::<Vec<_>>();

        let mut py_type_aliases = type_aliases
            .iter()
            .map(|c| ir_to_py::type_aliases::ir_type_alias_to_py(c.item, &pkg))
            .collect::<Vec<_>>();
        py_type_aliases.sort_by(|a, b| a.name.cmp(&b.name));

        pkg.set("baml_client.type_map");
        collector.add_file("type_map.py", render_type_map(&py_classes, &enums)?)?;

        pkg.set("baml_client.type_builder");

        collector.add_file(
            "type_builder.py",
            render_py_type_builder(&py_classes, &enums)?,
        )?;

        pkg.set("baml_client.types");
        collector.add_file("types.py", render_py_types_utils(&pkg)?)?;
        collector.append_to_file("types.py", &render_py_types(&enums, &pkg)?)?;
        collector.append_to_file("types.py", &render_py_types(&py_classes, &pkg)?)?;
        collector.append_to_file("types.py", &render_py_types(&py_type_aliases, &pkg)?)?;

        let mut py_stream_type_aliases = type_aliases
            .iter()
            .map(|c| ir_to_py::type_aliases::ir_type_alias_to_py_stream(c.item, &pkg))
            .collect::<Vec<_>>();
        py_stream_type_aliases.sort_by(|a, b| a.name.cmp(&b.name));

        let py_classes = ir
            .walk_classes()
            .map(|c| ir_to_py::classes::ir_class_to_py_stream(c.item, &pkg))
            .collect::<Vec<_>>();

        pkg.set("baml_client.stream_types");
        collector.add_file("stream_types.py", render_py_stream_types_utils(&pkg)?)?;
        collector.append_to_file("stream_types.py", &render_py_types(&py_classes, &pkg)?)?;
        collector.append_to_file(
            "stream_types.py",
            &render_py_types(&py_stream_type_aliases, &pkg)?,
        )?;

        Ok(())
    }
}

// Hooks the Python generator into the shared code-gen test suites from `test_harness`.
#[cfg(test)]
mod python_tests {
    use test_harness::{create_code_gen_test_suites, TestLanguageFeatures};

    impl TestLanguageFeatures for crate::PyLanguageFeatures {
        // Label the harness uses for this generator's suites (presumably it names the
        // test output/fixture locations — confirm in `test_harness`).
        fn test_name() -> &'static str {
            "python"
        }
    }

    // Expands to the test suites for `PyLanguageFeatures`. The type is passed by its
    // absolute `crate::` path so it still resolves if the macro expands into nested modules.
    create_code_gen_test_suites!(crate::PyLanguageFeatures);
}

#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use baml_types::GeneratorOutputType;
    use dir_writer::LanguageFeatures;

    use crate::PyLanguageFeatures;

    /// `PyLanguageFeatures::name()` must parse back into the Python/Pydantic
    /// generator output type.
    #[test]
    fn test_name() {
        let parsed = GeneratorOutputType::from_str(PyLanguageFeatures::name())
            .expect("PyLanguageFeatures name should be a valid GeneratorOutputType");
        assert_eq!(parsed, GeneratorOutputType::PythonPydantic);
    }
}
